import torch.nn as nn
import torch
from ..op import block_linear_custom
from .basiclinear import BasicLinear


class BlockLinear(BasicLinear):
    """Linear layer factorised as a dense factor times a block-diagonal factor.

    Two parameters are stored:

    * ``blkdiag`` with shape ``(nblocks, rank // nblocks, in_features // nblocks)``
    * ``lr`` with shape ``(out_features, rank)``

    The equivalent dense weight, as materialised in :meth:`post_init` for the
    guide layer, is ``lr @ block_diag(blkdiag[0], ..., blkdiag[nblocks - 1])``
    with shape ``(out_features, in_features)``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias,
        return_bias,
        config,
        init_config,
        device="cuda",
    ):
        """Allocate the two factors and run weight initialisation.

        Args:
            in_features: Input width; must be divisible by ``config["nblocks"]``.
            out_features: Output width.
            bias, return_bias, init_config: Forwarded unchanged to ``BasicLinear``.
            config: Mapping providing at least ``"rank"`` and ``"nblocks"``.
                ``rank`` must be divisible by ``nblocks``.
            device: Device on which the parameters are allocated.

        Raises:
            AssertionError: If ``in_features`` or ``rank`` is not divisible by
                ``nblocks``.
        """
        super().__init__(
            in_features, out_features, bias, return_bias, config, init_config, device
        )
        self.rank = config["rank"]
        self.nblocks = config["nblocks"]
        assert (
            self.in_features % self.nblocks == 0
        ), "in_features must be divisible by nblocks"
        assert self.rank % self.nblocks == 0, "rank must be divisible by nblocks"

        # Parameters are allocated uninitialised; values are filled in by
        # _init_weights() below (presumably provided by BasicLinear).
        self.blkdiag = nn.Parameter(
            torch.empty(
                self.nblocks,
                self.rank // self.nblocks,
                self.in_features // self.nblocks,
                device=device,
            )
        )
        self.lr = nn.Parameter(torch.empty(self.out_features, self.rank, device=device))

        self._init_weights()
        self.post_init()

    def get_weights(
        self,
    ):
        """Return the trainable factors as ``[blkdiag, lr]``."""
        return [self.blkdiag, self.lr]

    @torch.no_grad()
    def post_init(
        self,
    ):
        """Optionally orthogonalise the factors and sync the guide weight.

        When ``self.config.init.post_init == "ortho"`` every block of
        ``blkdiag`` and the whole of ``lr`` is replaced by ``U @ Vh`` from its
        reduced SVD, i.e. the singular values are dropped. If the instance has
        a ``guide_linear`` attribute, its data is then overwritten with the
        dense product of the two factors.
        """
        if self.config.init.post_init == "ortho":
            for i in range(self.nblocks):
                U, _, Vh = torch.linalg.svd(self.blkdiag.data[i], full_matrices=False)
                self.blkdiag.data[i] = torch.mm(U, Vh)
            U, _, Vh = torch.linalg.svd(self.lr.data, full_matrices=False)
            self.lr.data = torch.mm(U, Vh)
        # init guide linear
        if hasattr(self, "guide_linear"):
            self.guide_linear.data = torch.mm(
                self.lr.data, torch.block_diag(*torch.unbind(self.blkdiag.data, dim=0))
            )

    @torch.no_grad()
    def old_frobgrad(self, wd=0.0):
        """Apply one in-place weight-decay step on the product ``lr @ blockdiag``.

        Deprecated: decaying the product showed no clear benefit while costing
        extra computation, so this method is kept for reference only.

        Args:
            wd: Decay coefficient. A falsy value makes the call a no-op.
        """
        if wd:
            if self.config.guide and self.gamma:
                self.guide_linear.data -= wd * self.gamma * self.guide_linear.data
                wd *= 1.0 - self.gamma
                self.gamma -= self.gamma_decay
                # Clamp at zero. The floor lives on the layer's own device
                # rather than a hard-coded CUDA device, so CPU runs also work.
                self.gamma = max(self.gamma, torch.tensor(0.0, device=self.lr.device))

            U = self.lr.data
            Vh = torch.block_diag(*torch.unbind(self.blkdiag.data, dim=0))
            # Gradients of 0.5 * ||U @ Vh||_F^2 with respect to U and Vh.
            # Both are computed before either factor is modified, because U
            # aliases self.lr.data. torch.linalg.multi_dot replaces the
            # deprecated torch.chain_matmul.
            tmp1 = torch.linalg.multi_dot([U, Vh, Vh.T])
            tmp2 = torch.linalg.multi_dot([U.T, U, Vh])
            self.lr.data -= wd * tmp1
            in_blksz = self.blkdiag.shape[2]
            out_blksz = self.blkdiag.shape[1]
            # Only the diagonal blocks of tmp2 map onto stored parameters.
            for i in range(self.nblocks):
                self.blkdiag.data[i] -= (
                    wd
                    * tmp2[
                        i * out_blksz : (i + 1) * out_blksz,
                        i * in_blksz : (i + 1) * in_blksz,
                    ]
                )

    def forward(self, input):
        """Run ``block_linear_custom`` on ``input``, then the guide-layer hook.

        The result of ``block_linear_custom(input, blkdiag, lr)`` is passed,
        together with the original input, to ``self.forward_guide_layer``,
        whose return value is returned unchanged.
        """
        out = block_linear_custom(input, self.blkdiag, self.lr)
        return self.forward_guide_layer(input, out)

    def extra_repr(self) -> str:
        """Return the factor shapes, bias presence and guide flag for ``repr``."""
        return f"blockdiag1={self.blkdiag.shape}, linear={self.lr.shape}, bias={self.bias is not None}, guide={self.training_config.enabled}"
